#!/usr/bin/env python3
import os,re
import numpy as np
from datetime import datetime,timedelta
import pendulum
import pandas as pd
from matplotlib.path import Path
import xarray as xr
import proplot as plot
#plot.rc['backend'] = 'Qt4Agg'
print(plot.rc['backend'])
import cartopy.crs as ccrs
from cartopy.io.shapereader import Reader
from cartopy.mpl.patch import geos_to_path
from cartopy.feature import ShapelyFeature
from shapely.geometry import MultiPolygon
import cmaps
import argparse
import concurrent.futures
from xml.dom import minidom

def parse_time(string):
    """Parse a command-line time of the form YYYYMMDD[HH[mm]] into a pendulum DateTime.

    Intended as an ``argparse`` ``type=`` converter.

    Args:
        string: exactly 8, 10 or 12 ASCII digits -- a date, optionally followed by
            the hour and then the minute.

    Returns:
        The result of ``pendulum.from_format`` with the matching format
        (pendulum's default timezone applies).

    Raises:
        argparse.ArgumentTypeError: if *string* is not 8, 10 or 12 ASCII digits, so
            argparse prints a usage error instead of a traceback.
        ValueError: propagated from pendulum for impossible dates (e.g. month 13);
            argparse also reports this as an invalid value.
    """
    # fullmatch + [0-9] rejects trailing junk and non-ASCII digits that the old,
    # unanchored r'\d...' pattern let through to pendulum.
    if re.fullmatch(r'[0-9]{8}(?:[0-9]{2}){0,2}', string) is None:
        raise argparse.ArgumentTypeError(
            f'invalid time {string!r}: expected YYYYMMDD, YYYYMMDDHH or YYYYMMDDHHmm')
    formats = {8: 'YYYYMMDD', 10: 'YYYYMMDDHH', 12: 'YYYYMMDDHHmm'}
    return pendulum.from_format(string, formats[len(string)])

def parsexml(filename):
    """Load station observations from an XML file into a DataFrame.

    Every ``<R>`` element becomes one row. Its ``Year``/``Mon``/``Day``/``Hour``/``Min``
    attributes are combined into the ``Time`` column. ``Station_Id_C`` is kept as
    text, and ``Lat``, ``Lon``, ``TEM``, ``RHU`` and ``PRS`` are converted to float.
    No missing-value filtering happens here.

    Args:
        filename: path of the XML file to read.

    Returns:
        pandas.DataFrame with columns Time, Station_Id_C, Lat, Lon, TEM, RHU, PRS.
    """
    print(f'[Notice]: Reading {filename}')
    columns = ['Time', 'Station_Id_C', 'Lat', 'Lon', 'TEM', 'RHU', 'PRS']
    stamp_keys = ('Year', 'Mon', 'Day', 'Hour', 'Min')
    float_keys = ('Lat', 'Lon', 'TEM', 'RHU', 'PRS')

    rows = []
    for node in minidom.parse(filename).getElementsByTagName('R'):
        attrs = node.attributes
        stamp = datetime(*[int(attrs[key].value) for key in stamp_keys])
        station = str(attrs['Station_Id_C'].value)
        measurements = [float(attrs[key].value) for key in float_keys]
        rows.append([stamp, station, *measurements])
    return pd.DataFrame(rows, columns=columns)




def _valid_obs_mask(series):
    """Return a boolean mask that is False where *series* holds 999999 or 999998.

    Those two sentinel values are treated here as missing observations.
    """
    return (series != 999999.0) & (series != 999998)


def _load_china_geoms():
    """Read ``shapefiles/china.shp`` next to this script and merge all records into one MultiPolygon."""
    shp_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'shapefiles', 'china.shp')
    return MultiPolygon([record.geometry for record in Reader(shp_file).records()])


def _plot_field(ds, df, column, analysis, levels, cmap, diff_levels, ascending,
                title, prefix, chn_geoms, fig_dir, fig_format):
    """Draw analysis / observation / observation-minus-analysis maps for one variable and save them.

    Args:
        ds: analysis Dataset at a single time, providing ``lon``, ``lat`` and ``time``.
        df: station observations as returned by ``parsexml``.
        column: observation column to plot ('TEM', 'RHU' or 'PRS').
        analysis: gridded analysis field (on ``lon``/``lat``), already in the observation's units.
        levels: colour levels shared by the analysis and observation panels.
        cmap: colormap for those two panels.
        diff_levels: colour levels for the o-a panel.
        ascending: station sort order before plotting (points drawn later sit on top).
        title: suptitle text; the valid time is appended after two spaces.
        prefix: output file prefix; ``_<YYYYmmddHHMM>.<fig_format>`` is appended.
        chn_geoms: boundary geometry drawn on every panel.
        fig_dir: output directory, created if missing.
        fig_format: figure file extension passed to ``savefig``.
    """
    # proplot figures are matplotlib figures; close each one after saving so they are released.
    import matplotlib.pyplot as plt

    cnfont = {'fontname': 'fangsong'}  # font able to render the Chinese titles
    obs = df[_valid_obs_mask(df[column])].sort_values(by=column, ascending=ascending)
    lon = obs['Lon'].values
    lat = obs['Lat'].values
    values = obs[column].values
    valid_time = pd.to_datetime(ds.time.values)

    f, axs = plot.subplots(ncols=3, nrows=1, figsize=(24, 8), proj='pcarree')
    axs.format(
        labels=True, latlines=10, lonlines=10,
        lonlim=(ds.lon.values[0], ds.lon.values[-1]),
        latlim=(ds.lat.values[0], ds.lat.values[-1]),
        suptitle='{}  {}'.format(title, valid_time.strftime('%Y-%m-%dT%H:%M')), **cnfont
    )

    # Panel 1: gridded analysis.
    im = axs[0].contourf(ds.lon, ds.lat, analysis, levels=levels, cmap=cmap)
    axs[0].colorbar(im, loc='b', length=0.9)
    axs[0].format(title='分析', **cnfont)

    # Panel 2: station observations on the same colour scale.
    im = axs[1].scatter(lon, lat, marker='o', c=values, s=2, lw=0.01,
                        cmap=cmap, levels=levels, edgecolors='k')
    axs[1].colorbar(im, loc='b', length=0.9)
    axs[1].format(title='观测', **cnfont)

    # Panel 3: observation minus analysis, the analysis linearly interpolated to each station.
    at_stations = analysis.interp(lon=xr.DataArray(lon, dims='station'),
                                  lat=xr.DataArray(lat, dims='station'),
                                  method='linear')
    im = axs[2].scatter(lon, lat, marker='o', c=values - at_stations.values, s=2, lw=0.01,
                        cmap=cmaps.BlWhRe, levels=diff_levels, edgecolors='k')
    axs[2].colorbar(im, loc='b', length=0.9, extend='both')
    axs[2].format(title='观测-分析(o-a)', **cnfont)

    axs.add_feature(ShapelyFeature(chn_geoms, ccrs.PlateCarree(), facecolor='none', edgecolor='k'))

    os.makedirs(fig_dir, exist_ok=True)
    fig_path = os.path.join(fig_dir, '{}_{}'.format(prefix, valid_time.strftime('%Y%m%d%H%M')))
    f.savefig(fig_path + '.' + fig_format, dpi=200, bbox_inches='tight')
    plt.close(f)


def plot_sfc(pcp, df, fig_dir, fig_format):
    """Plot 2 m temperature, 2 m humidity and surface pressure: analysis vs. station observations.

    Writes three figures into *fig_dir*: ``2t_<time>``, ``2rh_<time>`` and ``sp_<time>``.
    ``<time>`` is the analysis valid time formatted as YYYYmmddHHMM, and each file gets
    the extension *fig_format*. Each figure has three panels: the analysis, the
    observations, and observation minus analysis.

    Args:
        pcp: analysis xarray Dataset at a single time. It must provide ``ts`` (converted
            from K to degC here), ``rhs`` (plotted as %) and ``ps`` (divided by 100 to
            give hPa) on ``lon``/``lat`` coordinates, plus a scalar ``time``.
        df: station observations from ``parsexml`` (columns Lon, Lat, TEM, RHU, PRS).
        fig_dir: output directory, created if missing.
        fig_format: figure file extension, e.g. 'png'.
    """
    # All fields come from the argument; nothing is read from module-level globals.
    ds = pcp
    chn_geoms = _load_china_geoms()
    common = dict(chn_geoms=chn_geoms, fig_dir=fig_dir, fig_format=fig_format)

    # 2 m temperature (K -> degC); coldest stations drawn first.
    _plot_field(ds, df, 'TEM', ds.ts - 273.15,
                levels=np.linspace(-20, 50, 141), cmap=cmaps.BlAqGrYeOrReVi200,
                diff_levels=np.linspace(-5, 5, 21), ascending=True,
                title=r'2米温度($^\circ$C)', prefix='2t', **common)

    # 2 m relative humidity (%); most humid stations drawn first.
    _plot_field(ds, df, 'RHU', ds.rhs,
                levels=np.linspace(0, 100, 41), cmap=cmaps.WhBlGrYeRe,
                diff_levels=np.linspace(-20, 20, 21), ascending=False,
                title='2米湿度(%)', prefix='2rh', **common)

    # Surface pressure (Pa -> hPa); highest-pressure stations drawn first.
    _plot_field(ds, df, 'PRS', ds.ps / 100.0,
                levels=np.linspace(500, 1100, 61), cmap=cmaps.nice_gfdl,
                diff_levels=np.linspace(-20, 20, 41), ascending=False,
                title='地面气压(hPa)', prefix='sp', **common)

    return



if __name__ == '__main__':
    # Command-line entry point: plot the MOC LAPS surface analysis against station observations.
    import sys

    parser = argparse.ArgumentParser(description='plot moc laps surface data.')
    parser.add_argument('-o', '--root-dir', dest='root_dir', default='/data/cma_moc/gfs-3km-prod', help='Root directory to store data.')
    parser.add_argument('--fig-dir', dest='fig_dir', default='/data/cma_moc/rtoas-figure/gfs-3km-prod', help='Root directory to store figure.')
    parser.add_argument('--output-prefix', dest='output_prefix', default='MOC_3KM_SFC', help='filename prefix.')
    parser.add_argument('--obs-dir', dest='obs_dir', default='/nas02/data/raw/cimiss/SURF_CHN_MAIN_MIN/', help='Directory of station observation XML files.')
    # required: every path below is derived from the time, so running without it cannot work.
    parser.add_argument('-t', '--time', required=True, help='file time (YYYYMMDD[HH[MM]]).', type=parse_time)
    parser.add_argument('-f', '--format', dest='format', default='png', help='figure format.')
    args = parser.parse_args()

    year_dir = args.time.format('YYYY')
    day_dir = args.time.format('YYYYMMDD')
    stamp = args.time.format('YYYYMMDDHHmm')
    file_path = os.path.join(args.root_dir, year_dir, day_dir, f'{args.output_prefix}_{stamp}.nc')
    obs_path = os.path.join(args.obs_dir, f'{stamp}.xml')
    # An hourly archive also exists, e.g.
    # /nas02/data/raw/cimiss_archive/SURF_CHN_MUL_HOR/20201130/202011300300.xml

    # Report missing inputs clearly instead of failing with a traceback (obs) or
    # a catch-all handler that also hid unrelated errors (analysis).
    for required_path in (file_path, obs_path):
        if not os.path.isfile(required_path):
            print(f'[Error]: not found {required_path}')
            sys.exit(1)

    df = parsexml(obs_path)
    try:
        # Load fully into memory, then close the underlying netCDF file.
        with xr.open_dataset(file_path) as nc:
            ds = nc.load()
    except (OSError, ValueError) as err:
        print(f'[Error]: cannot read {file_path}: {err}')
        sys.exit(1)
    ds = ds.isel(time=0)
    plot_sfc(ds, df, os.path.join(args.fig_dir, year_dir, day_dir), args.format)
